import syscall
import os
import time
import fs
import mem
import fmt
import libc
import unsafe
import io
import io.buf

// logf formats `format` with `args` and prints the result on one line,
// prefixed with the value of time.now().ms_timestamp().
fn logf(string format, ...[any] args) {
    // Render the message before reading the clock, as the original did.
    var text = fmt.sprintf(format, ...args)
    var stamp = time.now().ms_timestamp()
    fmt.printf('%v - %v\n', stamp, text)
}


// config_t holds the model hyperparameters. transformer_new fills it by
// copying the first @sizeof(config_t) bytes of the checkpoint file into it,
// so the field order and field widths must stay exactly as they are.
type config_t = struct{
    i32 dim         // width of the activation vector x
    i32 hidden_dim  // width of the ffn hidden buffers hb / hb2
    i32 n_layers    // number of transformer layers
    i32 n_heads     // number of query heads; head_size = dim / n_heads
    i32 n_kv_heads  // number of key/value heads; kv_dim = dim * n_kv_heads / n_heads
    i32 vocab_size  // vocabulary size; transformer_new negates a non-positive value read from the file
    i32 seq_len     // number of positions the kv cache and attention buffer are sized for
}

// prob_index_t pairs a probability with the vocabulary index it belongs to.
// sample_topp fills a vec of these and sorts it by prob.
type prob_index_t = struct{
    f32 prob   // probability of the token
    int index  // position of the token in the probabilities vector
}

// sampler_t carries the sampling parameters and scratch state used by
// sampler_t.sample to turn logits into a token index.
type sampler_t = struct{
    int vocab_size               // number of logits that sample() looks at
    vec<prob_index_t> probindex  // scratch buffer used by sample_topp
    f32 temperature              // logits are divided by this; 0.0 selects argmax
    f32 topp                     // top-p threshold; values <= 0 or >= 1 select plain multinomial sampling
    u64 rng_state                // xorshift state, advanced on every random_f32 call
}

// token_index_t pairs a vocabulary string with its token id. A vec of these,
// sorted by str, is what str_lookup searches.
type token_index_t = struct{
    string str  // the token text
    int id      // index of the token in tokenizer_t.vocab
}

// tokenizer_t holds the vocabulary loaded by tokenizer_new plus the lookup
// structures used by encode and decode.
type tokenizer_t = struct{
    [string] vocab                // token text, indexed by token id
    [f32] vocab_scores            // score per token id; encode merges the highest-scoring pair first
    [token_index_t] sorted_vocab  // built lazily on the first encode() call, sorted by str
    int vocab_size                // number of entries read from the tokenizer file
    u32 max_token_length          // first u32 of the tokenizer file; sizes encode's scratch buffer
    [u8;512] byte_pieces          // 256 two-byte entries (byte value, 0), filled in tokenizer_new
}


// transformer_weights is a set of raw pointers to the individual weight
// tensors. memory_map_weights derives every pointer from the weights_ptr it
// is given (transformer_new passes the address just past the config header of
// the mmap'ed checkpoint), so nothing here is separately allocated.
type transformer_weights = struct{
    // token embedding table
    rawptr<f32> token_embedding_table    // (vocab_size, dim)

    // weights for rmsnorms
    rawptr<f32> rms_att_weight           // (layer, dim)
    rawptr<f32> rms_ffn_weight           // (layer, dim)

    // weights for matmuls. note dim == n_heads * head_size
    rawptr<f32> wq                       // (layer, dim, n_heads * head_size)
    rawptr<f32> wk                       // (layer, dim, n_kv_heads * head_size)
    rawptr<f32> wv                       // (layer, dim, n_kv_heads * head_size)
    rawptr<f32> wo                       // (layer, n_heads * head_size, dim)

    // weights for ffn
    rawptr<f32> w1                       // (layer, hidden_dim, dim)
    rawptr<f32> w2                       // (layer, dim, hidden_dim)
    rawptr<f32> w3                       // (layer, hidden_dim, dim)

    // final rmsnorm
    rawptr<f32> rms_final_weight         // (dim,)

    // (optional) classifier weights for the logits, on the last layer;
    // equal to token_embedding_table when the weights are shared
    rawptr<f32> wcls
}

// run_state_t holds the scratch buffers and the kv cache used by
// transformer.forward. All buffers are allocated by malloc_run_state.
type run_state_t = struct{
    // current wave of activations
    vec<f32> x        // activation at current time stamp (dim,)
    vec<f32> xb       // same, but inside a residual branch (dim,)
    vec<f32> xb2      // an additional buffer just for convenience (dim,)
    vec<f32> hb       // buffer for hidden dimension in the ffn (hidden_dim,)
    vec<f32> hb2      // buffer for hidden dimension in the ffn (hidden_dim,)
    vec<f32> q        // query (dim,)
    vec<f32> k        // key (kv_dim,); forward rebinds it to a slice of key_cache
    vec<f32> v        // value (kv_dim,); forward rebinds it to a slice of value_cache
    vec<f32> att      // buffer for scores/attention values (n_heads, seq_len)
    vec<f32> logits   // output logits (vocab_size,)
    // kv cache
    vec<f32> key_cache    // (layer, seq_len, kv_dim)
    vec<f32> value_cache  // (layer, seq_len, kv_dim)
}

// transformer bundles everything needed to run the model: hyperparameters,
// weight pointers, run-time buffers and the handle of the mapped checkpoint.
type transformer = struct{
    config_t config              // hyperparameters read from the checkpoint header
    transformer_weights weights  // pointers into the mapped checkpoint data
    run_state_t state            // activation buffers and kv cache
    int fd                       // file descriptor of the opened checkpoint
    ptr<f32> data                // start address returned by mmap
    uint file_size               // size of the checkpoint file in bytes
}

// malloc_run_state allocates every buffer of the run state `s`, zero-filled,
// with sizes derived from the hyperparameters in `cfg`.
fn transformer.malloc_run_state(rawptr<run_state_t> s, rawptr<config_t> cfg):void! {
    // width of one key/value vector (smaller than dim when kv heads are shared)
    i32 kv_dim = (cfg.dim * cfg.n_kv_heads) / cfg.n_heads

    // element counts, computed once up front
    int dim_len = cfg.dim as int
    int hidden_len = cfg.hidden_dim as int
    int kv_len = kv_dim as int
    int att_len = (cfg.n_heads * cfg.seq_len) as int
    int cache_len = (cfg.n_layers * cfg.seq_len * kv_dim) as int

    // activation buffers
    s.x = vec_new<f32>(0, dim_len)
    s.xb = vec_new<f32>(0, dim_len)
    s.xb2 = vec_new<f32>(0, dim_len)
    s.hb = vec_new<f32>(0, hidden_len)
    s.hb2 = vec_new<f32>(0, hidden_len)

    // query / key / value for the current position
    s.q = vec_new<f32>(0, dim_len)
    s.k = vec_new<f32>(0, kv_len)
    s.v = vec_new<f32>(0, kv_len)

    // attention scores and output logits
    s.att = vec_new<f32>(0, att_len)
    s.logits = vec_new<f32>(0, cfg.vocab_size as int)

    // kv cache covering every layer and position
    s.key_cache = vec_new<f32>(0, cache_len)
    s.value_cache = vec_new<f32>(0, cache_len)
}

// memory_map_weights points every field of self.weights at its tensor inside
// the checkpoint data that starts at `weights_ptr`.
//
// The tensors are laid out back to back in this order: token embedding,
// rms_att, wq, wk, wv, wo, rms_ffn, w1, w2, w3, rms_final, two RoPE frequency
// tables (skipped) and finally the classifier weights.
//
// weights_ptr:    address of the first weight (just past the config header)
// shared_weights: when > 0, wcls aliases token_embedding_table instead of
//                 pointing at the trailing classifier tensor
fn transformer.memory_map_weights(anyptr weights_ptr, u8 shared_weights):void! {
    config_t cfg = self.config

    // Widen every dimension to int before multiplying. The products below are
    // byte counts and exceed the i32 range for large models (for example
    // 32 * 4096 * 4096 * 4), which would silently corrupt every later pointer.
    int dim = cfg.dim as int
    int hidden_dim = cfg.hidden_dim as int
    int n_layers = cfg.n_layers as int
    int n_heads = cfg.n_heads as int
    int n_kv_heads = cfg.n_kv_heads as int
    int vocab_size = cfg.vocab_size as int
    int seq_len = cfg.seq_len as int
    int head_size = dim / n_heads

    self.weights.token_embedding_table = weights_ptr as rawptr<f32>
    weights_ptr += (vocab_size * dim * @sizeof(f32)) as anyptr

    self.weights.rms_att_weight = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * @sizeof(f32)) as anyptr

    self.weights.wq = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * (n_heads * head_size) * @sizeof(f32)) as anyptr

    self.weights.wk = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * (n_kv_heads * head_size) * @sizeof(f32)) as anyptr

    self.weights.wv = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * (n_kv_heads * head_size) * @sizeof(f32)) as anyptr

    self.weights.wo = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * (n_heads * head_size) * dim * @sizeof(f32)) as anyptr

    self.weights.rms_ffn_weight = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * @sizeof(f32)) as anyptr

    self.weights.w1 = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * hidden_dim * @sizeof(f32)) as anyptr

    self.weights.w2 = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * hidden_dim * dim * @sizeof(f32)) as anyptr

    self.weights.w3 = weights_ptr as rawptr<f32>
    weights_ptr += (n_layers * dim * hidden_dim * @sizeof(f32)) as anyptr

    self.weights.rms_final_weight = weights_ptr as rawptr<f32>
    weights_ptr += (dim * @sizeof(f32)) as anyptr

    // skip the RoPE frequency tables; forward() recomputes the rotation itself
    weights_ptr += (seq_len * head_size / 2 * @sizeof(f32)) as anyptr // skip freq_cis_real
    weights_ptr += (seq_len * head_size / 2 * @sizeof(f32)) as anyptr // skip freq_cis_imag

    // with shared classifier weights the embedding table doubles as wcls
    if shared_weights > 0 {
        self.weights.wcls = self.weights.token_embedding_table
    } else {
        self.weights.wcls = weights_ptr as rawptr<f32>
    }
}

// rmsnorm writes the RMS-normalised, weight-scaled version of `x` into `o`:
//   o[j] = weight[j] * x[j] / sqrt(mean(x^2) + 1e-5)
// `weight` is a raw address of `size` consecutive f32 values. `o` and `x` may
// be the same vec (forward calls it that way for the final norm).
fn transformer.rmsnorm([f32] o, [f32] x, anyptr weight, int size) {
    // sum of squares over the first `size` elements
    f32 scale = 0.0
    for int i = 0; i < size; i += 1 {
        scale += x[i] * x[i]
    }

    // turn it into 1 / sqrt(mean + eps)
    scale /= size as f32
    scale += 1e-5
    scale = 1.0 / libc.sqrtf(scale)

    // normalise and apply the per-element gain
    for int i = 0; i < size; i += 1 {
        rawptr<f32> gain_ptr = (weight + i as anyptr * @sizeof(f32)) as rawptr<f32>
        f32 gain = *gain_ptr
        o[i] = gain * (scale * x[i])
    }
}

// matmul computes xout = W @ x on raw f32 memory.
//   W    (d, n) row-major, starting at address `w`
//   x    (n,)   starting at address `x_ref`
//   xout (d,)   starting at address `xout_ref`
// Elements are addressed with a fixed stride of 4 bytes.
fn transformer.matmul(anyptr xout_ref, anyptr x_ref, anyptr w, int n, int d):void! {
    // TODO #pragma optimize
    for int row = 0; row < d; row+=1 {
        // address of the first element of this row of W
        anyptr w_row = w + (row * n * 4) as anyptr

        f32 acc = 0.0
        for int col = 0; col < n; col+=1 {
            anyptr off = (col * 4) as anyptr
            f32 w_val = *((w_row + off) as rawptr<f32>)
            f32 x_val = *((x_ref + off) as rawptr<f32>)
            acc += w_val * x_val
        }

        // xout[row] = acc
        *((xout_ref + (row * 4) as anyptr) as rawptr<f32>) = acc
    }
}

// softmax replaces the first `size` elements of `x`, in place, with their
// softmax. The maximum is subtracted before exponentiating to keep expf in
// range. `x` must hold at least one element.
fn softmax([f32] x, int size) {
    // largest element
    f32 peak = x[0]
    for int k = 1; k < size; k += 1 {
        f32 cur = x[k]
        if cur > peak {
            peak = cur
        }
    }

    // exponentiate and accumulate the normaliser
    f32 total = 0.0
    for int k = 0; k < size; k += 1 {
        f32 e = libc.expf(x[k] - peak)
        x[k] = e
        total += e
    }

    // scale so the elements sum to one
    for int k = 0; k < size; k += 1 {
        x[k] /= total
    }
}

// forward runs the transformer for a single token and returns the logits.
//
// token: id of the input token (row of the embedding table to start from)
// pos:   position of the token in the sequence; attention looks at cache
//        entries 0..pos inclusive
// returns self.state.logits (the same buffer on every call)
//
// Side effects: writes this position's key/value vectors into the kv cache
// and overwrites the scratch buffers in self.state.
fn transformer.forward(int token, int pos):[f32]! {
    rawptr<config_t> p = &self.config
    rawptr<transformer_weights> w = &self.weights
    rawptr<run_state_t> s = &self.state

    int dim = p.dim as int
    int kv_dim = ((p.dim * p.n_kv_heads) / p.n_heads) as int
    int kv_mul = (p.n_heads / p.n_kv_heads) as int // integer multiplier of the kv sharing in multiquery
    int hidden_dim = p.hidden_dim as int
    int head_size = dim / p.n_heads as int

    // logf('[forward] start forward all the layers, xb.len=%d', s.xb.len())

    // copy the token embedding into x (a fresh x buffer of dim elements is allocated first)
    s.x = vec_new<f32>(0, dim)
    anyptr content_row = w.token_embedding_table as anyptr + token as anyptr * dim as anyptr * @sizeof(f32)
    libc.memcpy(s.x.ref(), content_row, dim as u64 * @sizeof(f32)) // raw, unchecked memory copy
    [f32] x = s.x

    // logf('token %d, forward will handle layers, pos %d s.xb2: %f, %f, %f, %f', token, pos, s.xb2[0], s.xb2[1], s.xb2[2], s.xb2[3])
    // logf('forward will handle layers, pos %d s.x: %f, %f, %f, %f', pos, s.x[0], s.x[1], s.x[2], s.x[3])


    // forward all the layers
    for int l = 0; l < p.n_layers as int; l+=1 {
        // attention rmsnorm
        self.rmsnorm(s.xb, x, w.rms_att_weight as anyptr + (l * dim * @sizeof(f32)) as anyptr, dim)

        // logf('forward will handle layers %d, pos %d, s.xb: %f, %f, %f, %f', l, pos, s.x[0], s.x[1], s.x[2], s.x[3])


        // key and value point into the kv cache
        // NOTE(review): the matmuls below write through s.k.ref() / s.v.ref(), so this
        // presumably relies on slice() sharing storage with the cache — confirm
        int loff = l * p.seq_len as int * kv_dim // kv cache layer offset for convenience
        int start = loff + pos * kv_dim
        s.k = s.key_cache.slice(start, start + kv_dim)
        s.v = s.value_cache.slice(start, start + kv_dim)

        // qkv matmuls for this position
        self.matmul(s.q.ref(), s.xb.ref(), w.wq as anyptr + (l * dim * dim * @sizeof(f32)) as anyptr, dim, dim)
        self.matmul(s.k.ref(), s.xb.ref(), w.wk as anyptr + (l * dim * kv_dim * @sizeof(f32)) as anyptr, dim, kv_dim)
        self.matmul(s.v.ref(), s.xb.ref(), w.wv as anyptr + (l * dim * kv_dim * @sizeof(f32)) as anyptr, dim, kv_dim)

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
        for int i = 0; i < dim; i+=2 {
            int head_dim = i % head_size
            f32 freq = 1.0 / libc.powf(10000.0, head_dim as f32 / head_size as f32)
            f32 val = pos as f32 * freq
            f32 fcr = libc.cosf(val)
            f32 fci = libc.sinf(val)
            // rotate q always; rotate k too while i is still inside the (shorter) key vector
            int rotn = 1
            if i < kv_dim {
                rotn = 2
            }
            for int v = 0; v < rotn; v+=1 {
                [f32] list = s.k // the vector to rotate (query or key)
                if v == 0 {
                    list = s.q
                }
                f32 v0 = list[i]
                f32 v1 = list[i+1]
                list[i] = v0 * fcr - v1 * fci
                list[i+1] = v0 * fci + v1 * fcr
            }
        }

        // multihead attention: iterate over all heads
        for int h_idx = 0; h_idx < p.n_heads as int; h_idx+=1 {
            // get the query vector for this head
            [f32] q = s.q.slice(h_idx * head_size, (h_idx + 1) * head_size)
            // attention scores for this head
            [f32] att = s.att.slice(h_idx * p.seq_len as int, (h_idx + 1) * p.seq_len as int)

            // iterate over all timesteps, including the current one
            for int t = 0; t <= pos; t += 1 {
                // get the key vector for this head and at this timestep
                int k_start = loff + t * kv_dim + (h_idx / kv_mul) * head_size
                [f32] k = s.key_cache.slice(k_start, k_start + head_size)

                // calculate the attention score as the dot product of q and k
                f32 score = 0.0
                for int i = 0; i < head_size; i += 1 {
                    score += q[i] * k[i]
                }
                score /= libc.sqrtf(head_size as f32)
                // save the score to the attention buffer
                att[t] = score
            }

            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1)

            // weighted sum of the values, store back into xb
            [f32] xb = s.xb.slice(h_idx * head_size, (h_idx + 1) * head_size)
            libc.memset(xb.ref(), 0, head_size as u64 * @sizeof(f32))

            for int t = 0; t <= pos; t+=1 {
                // get the value vector for this head and at this timestep
                int v_start = loff + t * kv_dim + (h_idx / kv_mul) * head_size
                [f32] v = s.value_cache.slice(v_start, v_start + head_size)
                // get the attention weight for this timestep
                f32 a = att[t]

                // accumulate the weighted value into xb
                for int i = 0; i < head_size; i+=1 {
                    xb[i] += a * v[i]
                }
            }
        }

        // println('loop h end, will set x from s.xb2', l, x[0], x[1], x[2], x[3])
        // println('loop h end, s.xb2', l, s.xb2[0], s.xb2[1], s.xb2[2], s.xb2[3])
        // println('loop h end, s.xb', l, s.xb[0], s.xb[1], s.xb[2], s.xb[3])

        // final matmul to get the output of the attention
        self.matmul(s.xb2.ref(), s.xb.ref(), w.wo as anyptr + (l * dim * dim * @sizeof(f32)) as anyptr, dim, dim)

        // println('loop h end after, s.xb2', l, s.xb2[0], s.xb2[1], s.xb2[2], s.xb2[3])

        // residual connection back into x
        for int i = 0; i < dim; i+=1 {
            x[i] += s.xb2[i]
        }

        // println('n layer first update xb', l, x[0], x[1], x[2], x[3])
        // println('n layer first update fro s.xb2', l, s.xb2[0], s.xb2[1], s.xb2[2], s.xb2[3])
        // println('rmsnorm before xb', l, s.xb[0], s.xb[1], s.xb[2], s.xb[3])


        // ffn rmsnorm
        self.rmsnorm(s.xb, x, w.rms_ffn_weight as anyptr + (l * dim * @sizeof(f32)) as anyptr, dim)

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        self.matmul(s.hb.ref(), s.xb.ref(), w.w1 as anyptr + (l * dim * hidden_dim * @sizeof(f32)) as anyptr, dim, hidden_dim)
        self.matmul(s.hb2.ref(), s.xb.ref(), w.w3 as anyptr + (l * dim * hidden_dim * @sizeof(f32)) as anyptr, dim, hidden_dim)

        // println('matmul after hb', l, s.hb[0], s.hb[1], s.hb[2], s.hb[3])
        // println('matmul after hb2', l, s.hb2[0], s.hb2[1], s.hb2[2], s.hb2[3], hidden_dim)

        // SwiGLU non-linearity
        for int i = 0; i < hidden_dim; i+=1 {
            f32 val = s.hb[i]
            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
            val *= (1.0 / (1.0 + libc.expf(-val)))
            // elementwise multiply with w3(x)
            val *= s.hb2[i]
            s.hb[i] = val
        }

        // println('final matmul before hb', l, s.hb[0], s.hb[1], s.hb[2], s.hb[3])

        // final matmul to get the output of the ffn
        self.matmul(s.xb.ref(), s.hb.ref(), w.w2 as anyptr + (l * hidden_dim * dim * @sizeof(f32)) as anyptr, hidden_dim, dim)

        // println('final matmul after xb', l, s.xb[0], s.xb[1], s.xb[2], s.xb[3])


        // residual connection
        for int i = 0; i < dim; i+=1 {
            x[i] += s.xb[i]
        }

        // println('n layer second update', l, x[0], x[1], x[2], x[3])
    }

    // final rmsnorm (in place: x is both input and output)
    self.rmsnorm(x, x, w.rms_final_weight as anyptr, dim)
    // println('x rmsnorm after', pos, s.logits[0], s.logits[1], s.logits[2], s.logits[3], x[0], x[1], x[2], x[3])


    // classifier into logits
    // NOTE(review): temp_weight is never read afterwards; looks like a debugging leftover
    var temp_weight = unsafe.vec_new(w.wcls, 4)
    self.matmul(s.logits.ref(), x.ref(), w.wcls as anyptr, dim, p.vocab_size as int)
    // println('logits matmul after', pos, s.logits[0], s.logits[1], s.logits[2], s.logits[3])


    // log('[forward] end forward, matmul end', s.logits.len())

    return s.logits
}

// read_checkpoint
// transformer_new opens the checkpoint at `path`, reads the config header,
// memory-maps the whole file, wires up the weight pointers and allocates the
// run-state buffers. The file is left open; its descriptor is stored in the
// returned transformer.
fn transformer_new(string path):ptr<transformer>! {
    var f = fs.open(path, syscall.O_RDONLY, 0)

    // read the config header: exactly @sizeof(config_t) bytes, copied verbatim
    var config_buf = vec_new<u8>(0, @sizeof(config_t))
    assert(f.read(config_buf) == @sizeof(config_t))
    var config = config_t{}
    mem.copy(config_buf, &config)

    // a positive vocab_size means the classifier shares the embedding table;
    // otherwise the stored value is negated to recover the real size
    u8 shared_weights = 1
    if config.vocab_size <= 0 {
        shared_weights = 0
        config.vocab_size = -config.vocab_size
    }

    // file size, needed as the mapping length
    var info = f.stat()
    u64 file_size = info.size

    // map the whole file read-only
    // NOTE(review): confirm that this assert also rejects mmap's failure value
    anyptr data = libc.mmap(0, file_size as int, libc.PROT_READ, libc.MAP_PRIVATE, f.fd, 0)
    assert(data > 0)

    // the weights start right after the config header
    anyptr weights_ptr = data as anyptr + @sizeof(config_t)

    var t = new transformer(config, fd = f.fd, data = data as ptr<f32>, file_size)
    t.memory_map_weights(weights_ptr, shared_weights)

    // allocate the RunState buffers
    t.malloc_run_state(&t.state, &t.config)

    return t
}

// decode returns the text for `token`, given the token that preceded it.
//
// prev_token: previously emitted token id; when it is BOS (1) one leading
//             space of the piece is dropped
// token:      token id to decode (index into self.vocab)
// returns the vocabulary string, or a one-byte string when the piece has the
// raw-byte form "<0xNN>"
fn tokenizer_t.decode(int prev_token, int token):string! {
    string piece = self.vocab[token]

    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89).
    // The length check keeps an empty piece from being indexed out of range.
    if prev_token == 1 && piece.len() > 0 && piece[0] == ' '.char() {
        piece = piece.slice(1, piece.len())
    }

    // pieces of the form "<0xNN>" stand for a single raw byte
    u8 byte_val = 0
    if fmt.sscanf(piece, "<0x%02x>", &byte_val) == 1 {
        return [byte_val] as string
    }

    return piece
}

// str_lookup returns the token id whose text equals `str`, or -1 when there
// is no exact match (or when the search itself fails). `sorted_vocab` must be
// sorted ascending by str. `vocab_size` is accepted but not used.
fn str_lookup(string str, [token_index_t] sorted_vocab, int vocab_size):int {
    // presumably search() yields the first index whose predicate holds,
    // i.e. the first entry that is >= str — verify against the vec API
    int pos = sorted_vocab.search(fn(int i):bool {
        return sorted_vocab[i].str >= str
    }) catch e {
        return -1
    }

    // nothing at or after str
    if pos >= sorted_vocab.len() {
        return -1
    }

    // the candidate is greater than str, so str itself is absent
    if sorted_vocab[pos].str != str {
        return -1
    }

    return sorted_vocab[pos].id
}

// encode turns `text` into token ids, written into `tokens`.
//
// text:   input string (raw UTF-8 bytes)
// bos:    non-zero to emit the BOS token (1) first
// eos:    non-zero to emit the EOS token (2) last
// tokens: caller-provided output buffer; it must be large enough for one
//         token per input byte plus the optional BOS/EOS and the dummy prefix
// returns the number of tokens written
//
// Steps: build the sorted vocabulary on first use, emit one token per UTF-8
// codepoint (falling back to per-byte tokens), then repeatedly merge the
// adjacent pair with the best vocabulary score until no pair can be merged.
fn tokenizer_t.encode(string text, i8 bos, i8 eos, [int] tokens):int! {
    // lazily build the sorted vocabulary that str_lookup searches
    if self.sorted_vocab.len() == 0 {
        self.sorted_vocab = vec_cap<token_index_t>(self.vocab_size)
        for int i = 0; i < self.vocab_size; i+=1 {
            self.sorted_vocab.push(token_index_t{
                str: self.vocab[i],
                id: i,
            })
        }
        assert(self.sorted_vocab.len() == self.vocab_size)

        self.sorted_vocab.sort(fn(int a, int b):bool {
            return self.sorted_vocab[a].str < self.sorted_vocab[b].str
        })
    }

    // start at 0 tokens
    int n_tokens = 0

    // add optional BOS (=1) token, if desired
    if bos != 0 {
        tokens[n_tokens] = 1
        n_tokens += 1
    }

    // add_dummy_prefix is true by default
    // so prepend a dummy prefix token to the input string, but only if text != ""
    // TODO: pretty sure this isn't correct in the general case but I don't have the
    // energy to read more of the sentencepiece code to figure out what it's doing
    if text.len() > 0 {
        int dummy_prefix = str_lookup(" ", self.sorted_vocab, self.vocab_size)

        tokens[n_tokens] = dummy_prefix
        n_tokens += 1
    }

    int len = text.len()

    // create a temporary buffer that will store merge candidates of always two consecutive tokens
    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
    var str_buffer = vec_new<u8>(0, self.max_token_length as int * 2 + 1 + 2)
    int str_len = 0

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point <-> UTF-8 conversion
    // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
    // U+0000            U+007F           0xxxxxxx
    // U+0080            U+07FF           110xxxxx  10xxxxxx
    // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
    // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx
    for int i = 0; i < len; i+=1 {
        u8 c = text[i]

        // reset buffer if the current byte is ASCII or a leading byte
        // 0xC0 is 11000000, so (c & 0xC0) keeps the first 2 bits and zeros the rest
        // 0x80 is 10000000
        // in UTF-8, all continuation bytes start with "10" in first two bits
        // so in English this is: "if this byte is not a continuation byte"
        if (c & 0xC0) != 0x80 {
            str_len = 0
        }

        // append the current byte to the buffer
        str_buffer[str_len] = c
        str_len += 1

        // keep collecting while the next byte is a continuation byte and the
        // codepoint is still shorter than 4 bytes
        if i < len - 1 && (text[i+1] & 0xC0) == 0x80 && str_len < 4 {
            continue
        }

        // a full codepoint has been collected: look it up
        int id = str_lookup(str_buffer.slice(0, str_len) as string, self.sorted_vocab, self.vocab_size)

        if id != -1 {
            // we found this codepoint in vocab, add it as a token
            tokens[n_tokens] = id
            n_tokens += 1
        } else {
            // byte_fallback encoding: just encode each byte as a token
            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
            // so the individual bytes only start at index 3
            for int j = 0; j < str_len; j+=1 {
                tokens[n_tokens] = (str_buffer[j] as int) + 3
                n_tokens += 1
            }
        }

        str_len = 0 // protect against a sequence of stray UTF8 continuation bytes
    }

    // merge the best consecutive pair each iteration, according the scores in vocab_scores
    for true {
        f32 best_score = -1e10
        int best_id = -1
        int best_idx = -1

        for int i = 0; i < (n_tokens - 1); i+=1 {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            string merge_str = self.vocab[tokens[i]] + self.vocab[tokens[i+1]]
            int id = str_lookup(merge_str, self.sorted_vocab, self.vocab_size)

            if id != -1 && self.vocab_scores[id] > best_score {
                // this merge pair exists in vocab! record its score and position
                best_score = self.vocab_scores[id]
                best_id = id
                best_idx = i
            }
        }

        if best_idx == -1 {
            break // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id

        // delete token at position best_idx+1, shift the entire sequence back 1
        for int i = best_idx + 1; i < n_tokens - 1; i+=1 {
            tokens[i] = tokens[i+1]
        }
        n_tokens -= 1 // token length decreased
    }

    // add optional EOS (=2) token, if desired
    if eos != 0 {
        tokens[n_tokens] = 2
        n_tokens += 1
    }

    return n_tokens
}

// tokenizer_new loads a tokenizer from the binary file at `tokenizer_path`.
//
// File layout as read here: one u32 max_token_length, then `vocab_size`
// records of (f32 score, i32 byte length, that many bytes of token text),
// all little-endian. The sorted vocabulary is left empty; encode builds it
// on first use.
fn tokenizer_new(string tokenizer_path, int vocab_size):ptr<tokenizer_t>! {
    var tok = new tokenizer_t(
        vocab = vec_cap<string>(0),
        vocab_scores = vec_cap<f32>(0),
        sorted_vocab = vec_cap<token_index_t>(0),
        vocab_size = vocab_size,
    )

    // byte_pieces holds, for every byte value, the pair (byte, 0)
    for int b = 0; b < 256; b+=1 {
        int slot = b * 2
        tok.byte_pieces[slot] = b as u8
        tok.byte_pieces[slot + 1] = 0
    }

    var fh = fs.open(tokenizer_path, syscall.O_RDONLY, 0)
    var reader = new buf.reader<ptr<fs.file_t>>(rd = fh)

    // header: maximum token length
    var word_buf = vec_new<u8>(0, 4)
    reader.read_exact(word_buf)
    tok.max_token_length = mem.read_u32_le(word_buf)

    // one record per vocabulary entry
    var score_buf = vec_new<u8>(0, 4)
    for int idx = 0; idx < vocab_size; idx+=1 {
        // score
        reader.read_exact(score_buf)
        tok.vocab_scores.push(mem.read_f32_le(score_buf))

        // length-prefixed token text
        reader.read_exact(word_buf)
        int piece_len = mem.read_i32_le(word_buf) as int

        var piece_buf = vec_new<u8>(0, piece_len)
        reader.read_exact(piece_buf)

        tok.vocab.push(piece_buf as string)
    }

    fh.close()
    return tok
}

// random_u32 advances the 64-bit generator state behind `state` by one
// xorshift* step and returns the upper 32 bits of the scrambled result.
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
fn random_u32(rawptr<u64> state):u32 {
    // work on a local copy, then store the new state back
    u64 x = *state
    x ^= x >> 12
    x ^= x << 25
    x ^= x >> 27
    *state = x

    return ((x * 0x2545F4914F6CDD1D) >> 32) as u32
}

// random_f32 returns a random float32 in [0,1): the top 24 bits of one
// random_u32 draw divided by 2^24. It advances `state` by one step.
fn random_f32(rawptr<u64> state):f32 {
    u32 bits = random_u32(state) >> 8
    return bits as f32 / 16777216.0
}

// sample_argmax returns the index of the largest of the first `n` entries of
// `probabilities`; on ties the lowest index wins.
fn sampler_t.sample_argmax([f32] probabilities, int n):int {
    int best = 0
    f32 best_p = probabilities[0]
    for int k = 1; k < n; k += 1 {
        f32 p = probabilities[k]
        if p > best_p {
            best_p = p
            best = k
        }
    }
    return best
}

// sample_mult draws an index from the first `n` entries of `probabilities`
// (they must sum to 1!) by walking the running sum until it exceeds `coin`.
fn sampler_t.sample_mult([f32] probabilities, int n, f32 coin):int {
    f32 running = 0.0
    for int k = 0; k < n; k += 1 {
        running += probabilities[k]
        if coin < running {
            return k
        }
    }

    // in case of rounding errors fall back to the last index
    return n - 1
}

// sample_topp performs top-p ("nucleus") sampling and returns a token index.
//
// probabilities: the distribution to sample from (first n entries are used)
// n:             number of candidate tokens
// topp:          cumulative-probability threshold of the nucleus
// probindex:     scratch buffer; overwritten with the candidates
// coin:          random number used to pick inside the nucleus
fn sampler_t.sample_topp([f32] probabilities, int n, f32 topp, vec<prob_index_t> probindex, f32 coin):int {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens whose probabilities sum to more than topp. This way very
    // low-probability tokens are never sampled, making "off the rails"
    // output less likely.
    
    // first, fill probindex with the probabilities and their indices;
    // only tokens whose probability is at least cutoff are considered
    int n0 = 0
    f32 cutoff = (1.0 - topp) / (n - 1) as f32
    
    for int i = 0; i < n; i += 1 {
        if probabilities[i] >= cutoff {
            probindex[n0].index = i
            probindex[n0].prob = probabilities[i]
            n0 += 1
        }
    }

    // logf('[sample_topp] n %d, cutoff: %f,  probindex: %d->%f, %d->%f, %d->%f', n, cutoff,
        // probindex[0].index, probindex[0].prob,  probindex[1].index, probindex[1].prob,  probindex[2].index, probindex[2].prob)

    // sort by probability, descending
    // NOTE(review): the later reads go through probindex itself, so this
    // presumably relies on slice(0, n0) sharing storage with probindex — confirm
    probindex.slice(0, n0).sort(fn(int a, int b):bool {
        return probindex[a].prob > probindex[b].prob
    })

    // logf('[sample_topp] sorted, n %d, cutoff: %f,  probindex: %d->%f, %d->%f, %d->%f', n, cutoff,
          //  probindex[0].index, probindex[0].prob,  probindex[1].index, probindex[1].prob,  probindex[2].index, probindex[2].prob)

    // truncate the list where the cumulative probability exceeds topp
    f32 cumulative_prob = 0.0
    int last_idx = n0 - 1  // in case of rounding errors consider all elements

    for int i = 0; i < n0; i += 1 {
        cumulative_prob += probindex[i].prob
        if cumulative_prob > topp {
            last_idx = i
            break  // we've exceeded topp by including last_idx
        }
    }



    // sample from the truncated list
    f32 r = coin * cumulative_prob
    f32 cdf = 0.0

    // logf('[sample_topp] r %f, last_id %d, index %d, cumulative_prob: %f, coin: %f', r, last_idx, probindex[last_idx].index, cumulative_prob, coin)

    for int i = 0; i <= last_idx; i += 1 {
        cdf += probindex[i].prob
        if r < cdf {
            // logf('[sample_topp] r %f < cdf %f, i %d, index %d', r, cdf, i, probindex[i].index)
            return probindex[i].index
        }
    }

    // in case of rounding errors
    return probindex[last_idx].index
}

// sample picks the next token index from `logits`.
//
// With temperature 0.0 it returns the argmax and leaves `logits` untouched.
// Otherwise `logits` is modified in place (temperature scaling followed by
// softmax), the rng state advances by one draw, and the token is chosen by
// plain multinomial sampling (topp <= 0 or >= 1) or by top-p sampling.
fn sampler_t.sample([f32] logits):int! {
    int n = self.vocab_size

    // greedy decoding
    if self.temperature == 0.0 {
        return self.sample_argmax(logits, n)
    }

    // apply the temperature, then turn the logits into probabilities
    f32 temp = self.temperature
    for int k = 0; k < n; k += 1 {
        logits[k] /= temp
    }
    softmax(logits, n)

    // one random draw decides the token
    f32 coin = random_f32(&self.rng_state)

    if self.topp <= 0 || self.topp >= 1 {
        return self.sample_mult(logits, n, coin)
    }
    return self.sample_topp(logits, n, self.topp, self.probindex, coin)
}

// sampler_new builds a sampler for a vocabulary of `vocab_size` tokens. The
// top-p scratch buffer gets one zero-valued entry per token and the rng
// starts from `rng_seed`.
fn sampler_new(int vocab_size, f32 temperature, f32 topp, u64 rng_seed):ptr<sampler_t>! {
    var scratch = vec_new<prob_index_t>(prob_index_t{}, vocab_size)

    return new sampler_t(
        vocab_size = vocab_size,
        probindex = scratch,
        temperature = temperature,
        topp = topp,
        rng_state = rng_seed,
    )
}

// error_usage prints the command-line help text and exits the process with
// status 1.
fn error_usage():void! {
    // synopsis
    var usage = 'Usage:   run <checkpoint> [options]\n'
    usage += 'Example: run model.bin -n 256 -i "Once upon a time"\n'

    // option list
    usage += 'Options:\n'
    usage += '  -t <float>  temperature in [0,inf], default 1.0\n'
    usage += '  -p <float>  p value in top-p (nucleus) sampling in [0,1] default 0.9\n'
    usage += '  -s <int>    random seed, default time(NULL)\n'
    usage += '  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n'
    usage += '  -i <string> input prompt\n'
    usage += '  -z <string> optional path to custom tokenizer\n'
    usage += '  -m <string> mode: generate|chat, default: generate\n'
    usage += '  -y <string> (optional) system prompt in chat mode\n'

    print(usage)
    syscall.exit(1)
}

// print_safe prints `piece` unless it is empty, starts with a NUL byte, or is
// a single raw byte that is neither printable nor whitespace. Raw-byte pieces
// can be arbitrary control codes, which should not reach the terminal.
fn print_safe(string piece) {
    if piece.len() == 0 {
        return
    }

    if piece[0] == 0 {
        return
    }

    // Strings here carry their length and have no terminating NUL, so a
    // one-byte piece is recognised by its length. The second test keeps the
    // earlier handling of a piece whose second byte is NUL.
    if piece.len() == 1 || piece[1] == 0 {
        u8 byte_val = piece[0]
        if !(libc.isprint(byte_val) || libc.isspace(byte_val)) {
            return // bad byte, don't print it
        }
    }

    print(piece)
}

// generate encodes `prompt`, feeds it through the transformer and then keeps
// sampling tokens, printing each decoded piece, until `steps` positions have
// been processed or a BOS (1) token is produced. Afterwards it prints the
// achieved tokens per second.
fn generate(ptr<transformer> tf, ptr<tokenizer_t> t, ptr<sampler_t> s, string prompt, int steps):void! {
    // encode the prompt; +3 for '\0', ?BOS, ?EOS
    var prompt_tokens = vec_new<int>(0, (prompt.len() + 3) as int)
    int prompt_len = t.encode(prompt, 1, 0, prompt_tokens)
    if prompt_len < 1 {
        panic("Something went wrong, expect at least 1 prompt word token")
    }

    i64 timer_start = 0            // set after the first iteration, see below
    int cur = prompt_tokens[0]     // token fed into the transformer
    int pos = 0                    // position in the sequence

    for pos < steps {
        // logits for the token that follows `cur`
        var logits = tf.forward(cur, pos)

        // while still inside the prompt the next token is forced,
        // afterwards it is sampled from the logits
        int picked = 0
        if pos < prompt_len - 1 {
            picked = prompt_tokens[pos + 1]
        } else {
            picked = s.sample(logits)
        }
        pos += 1

        // data-dependent terminating condition: the BOS (=1) token delimits sequences
        if picked == 1 {
            break
        }

        // decode and print, skipping "unsafe" bytes
        var piece = t.decode(cur, picked)
        print_safe(piece)

        cur = picked

        // start timing only now because the first iteration can be slower
        if timer_start == 0 {
            timer_start = time.now().ms_timestamp()
        }
    }

    print('\n')

    // tok/s; pos-1 because the timer is started after the first iteration
    if pos > 1 {
        i64 timer_end = time.now().ms_timestamp()
        fmt.printf("achieved tok/s: %f\n", (pos-1) as f64 / (timer_end-timer_start) as f64 * 1000.0)
    }
}

// main parses the command line, loads the model, tokenizer and sampler and
// runs the selected mode. Only the 'generate' mode is implemented; any other
// mode raises an error.
fn main():void! {
    // default parameters
    string checkpoint_path = ''
    string tokenizer_path = 'tokenizer.bin'
    f32 temperature = 1.0
    f32 topp = 0.9 // top-p sampling is used by default
    int steps = 256
    string prompt = ''
    u64 rng_seed = 0
    string mode = 'generate' // generate or chat
    string system_prompt = ''

    // read args: the checkpoint path is mandatory
    var args = os.args()
    var args_len = args.len()
    if args_len >= 2 {
        checkpoint_path = args[1]
    } else {
        error_usage()
    }

    // the remaining arguments come in "-x value" pairs
    for int i = 2; i < args_len; i+=2 {
        if i + 1 >= args_len { error_usage() } // must have arg after flag
        // check the length before indexing, so an empty argument prints the
        // usage text instead of failing on an out-of-range index
        if args[i].len() != 2 { error_usage() }  // must be -x (one dash, one letter)
        if args[i][0] != '-'.char() { error_usage() } // must start with dash

        match args[i][1] {
            't'.char() -> { temperature = args[i+1].to_float() as f32 }
            'p'.char() -> { topp = args[i+1].to_float() as f32 }
            's'.char() -> { rng_seed = args[i+1].to_int() as u64 }
            'n'.char() -> { steps = args[i+1].to_int() }
            'i'.char() -> { prompt = args[i+1] }
            'z'.char() -> { tokenizer_path = args[i+1] }
            'm'.char() -> { mode = args[i+1] }
            'y'.char() -> { system_prompt = args[i+1] }
            _ -> {
                error_usage()
            }
        }
    }

    // parameter validation
    if rng_seed == 0 {
        // no seed given: derive one from the clock
        rng_seed = time.unix() as u64
    }
    if temperature < 0.0 {
        temperature = 0.0
    }
    if topp < 0.0 || topp > 1.0 {
        topp = 0.0
    }
    if steps < 0 {
        steps = 0
    }

    // build the Transformer via the model .bin file
    var tf = transformer_new(checkpoint_path)

    // 0 means "as many as the model allows"; never exceed seq_len
    if steps == 0 || steps > tf.config.seq_len as int {
        steps = tf.config.seq_len as int
    }

    // build the Tokenizer via the tokenizer .bin file
    var t = tokenizer_new(tokenizer_path, tf.config.vocab_size as int)

    // build the Sampler
    var s = sampler_new(tf.config.vocab_size as int, temperature, topp, rng_seed)

    if mode == 'generate' {
        generate(tf, t, s, prompt, steps)
    } else {
        throw errorf('unknown mode: %s', mode)
    }
}
